# -*- coding: utf-8 -*- #
"""
Time                2023/5/5 11:18
Author:             mingfeng (SunnyQjm)
Email               mfeng@linux.alibaba.com
File                prometheus_metric_reader.py
Description:        Metric reader that queries a Prometheus server through its HTTP API (/api/v1/*).
"""
import requests
from typing import List
from enum import Enum
from urllib.parse import urljoin
from .result import *
from .metric_reader import MetricReader
from .url import MetricReaderUrl
from .task import RangeQueryTask, QueryTask, InstantQueryTask
from .filter import FilterType
from .common import StaticConst

# Prometheus HTTP API paths used by PrometheusMetricReader; they are joined
# onto the reader's base URL.
# NOTE: despite its name, GET_LABEL_NAMES is the *series* endpoint — label
# names are read from the keys of the series it returns.
GET_LABEL_NAMES: str = "/api/v1/series"
RANGE_QUERY_API: str = "/api/v1/query_range"
QUERY_API: str = "/api/v1/query"
METRIC_METADATA: str = "/api/v1/metadata"
# Prefix only: callers append "/<label_name>/values" to it.
GET_LABEL_VALUE: str = "/api/v1/label"

# Tells PrometheusMetricReader._parse_res how to read each result item:
# ``instant`` reads a single ``value`` (from /api/v1/query), ``range`` reads
# a ``values`` list (from /api/v1/query_range).
QueryType = Enum("QueryType", [("instant", 0), ("range", 1)], module=__name__)

class PrometheusMetricReader(MetricReader):
    """MetricReader that talks to a Prometheus server over its HTTP API.

    Each public method sends HTTP GET request(s) to an ``/api/v1/*``
    endpoint and wraps the outcome in a ``MetricResult``:

    * On success, ``code == 0`` and ``data`` is filled.
    * On failure, ``code == 1`` and ``err_msg`` is set. A failure is an
      HTTP status other than 200, or a JSON body whose ``status`` is not
      ``"success"``.

    Exceptions raised by ``requests`` itself (e.g. connection errors) are
    not caught and propagate to the caller.
    """

    # Timeout in seconds forwarded to every ``requests.get`` call. ``None``
    # is the requests default and waits indefinitely. Set a number on an
    # instance or subclass so an unreachable server cannot hang the caller.
    request_timeout = None

    def __init__(self, url: MetricReaderUrl, **kwargs):
        super().__init__(url, **kwargs)
        # The TLS special parameter selects https. Only the netloc
        # (host[:port]) of the reader URL goes into base_url.
        protocol = "https" if self.get_special_param(StaticConst.TLS) else "http"
        self.base_url = f"{protocol}://{url.netloc}"

    def _get_url(self, api: str):
        """Return the absolute URL of the API path *api* on this server."""
        return urljoin(self.base_url, api)

    def _get(self, api: str, params=None):
        """Send a GET request for *api* with query *params*; return the response.

        requests leaves out params whose value is None, so optional fields
        such as an unset query time are simply not sent.
        """
        return requests.get(self._get_url(api), params=params,
                            timeout=self.request_timeout)

    @staticmethod
    def _escape_label_value(value) -> str:
        """Escape *value* for embedding in a double-quoted PromQL string.

        Unescaped backslashes, double quotes or newlines would end the
        literal early or be read as escape sequences, corrupting the query.
        """
        return (str(value)
                .replace("\\", "\\\\")
                .replace('"', '\\"')
                .replace("\n", "\\n"))

    def _get_basic_promql_query(self, task: QueryTask) -> str:
        """Build the selector ``metric_name{matcher,...}`` for *task*.

        Matchers are built per filter type:

        * Equal filters become ``label="value"``.
        * Wildcard filters become ``label=~"pattern"``, with every ``*``
          expanded to ``(.*?)``.
        * Other filter types are ignored.

        Without filters the bare metric name is returned.
        """
        promql_str = task.metric_name
        if task.filters:
            rules = []
            for flt in task.filters:
                if flt.filter_type == FilterType.Equal:
                    value = self._escape_label_value(flt.value)
                    rules.append(f'{flt.label_name}="{value}"')
                elif flt.filter_type == FilterType.Wildcard:
                    pattern = str(flt.value).replace("*", "(.*?)")
                    pattern = self._escape_label_value(pattern)
                    rules.append(f'{flt.label_name}=~"{pattern}"')
            promql_str = promql_str + "{" + ",".join(rules) + "}"
        return promql_str

    @staticmethod
    def _check_response(res, http_err_msg: str, api_err_prefix: str):
        """Validate a raw HTTP response.

        Returns ``(mr, payload)``:

        * On success, ``mr`` is an empty successful ``MetricResult`` and
          ``payload`` is the decoded JSON body.
        * On failure, ``mr`` has ``code = 1`` and ``err_msg`` set, and
          ``payload`` is None.

        Which error message is used:

        * *http_err_msg* when the HTTP status is not 200.
        * *api_err_prefix* followed by the body's ``error`` field when
          ``status`` is not ``"success"``.
        """
        mr = MetricResult(0, [])
        if res.status_code != 200:
            mr.code = 1
            mr.err_msg = http_err_msg
            return mr, None
        payload = res.json()
        if payload.get("status") != "success":
            mr.code = 1
            # .get: a malformed error body must not turn into a KeyError.
            mr.err_msg = f"{api_err_prefix}{payload.get('error', 'unknown error')}"
            return mr, None
        return mr, payload

    @staticmethod
    def _merge_results(results) -> MetricResult:
        """Concatenate the ``data`` of each result in *results*, in order.

        Stops at the first failed result and copies its code and err_msg.
        Data gathered before the failure is kept. *results* is consumed
        lazily, so a generator of queries stops issuing requests after the
        first failure.
        """
        merged = MetricResult(0, [])
        for mr in results:
            if mr.code != 0:
                merged.code = mr.code
                merged.err_msg = mr.err_msg
                break
            merged.data.extend(mr.data)
        return merged

    def _parse_res(self, query_type, res) -> MetricResult:
        """Convert a query / query_range response *res* into a MetricResult.

        * Range results become RangeVectorResult items built from each
          series' ``values``.
        * Instant results become InstantVectorResult items built from
          ``value``.

        A series whose labels lack ``__name__`` gets an empty metric name
        instead of raising KeyError.
        """
        mr, payload = self._check_response(
            res, "Request failed, status_code != 200", "Prometheus API error: ")
        if payload is None:
            return mr
        items = payload["data"]["result"]
        if query_type == QueryType.range:
            mr.data = [
                RangeVectorResult(item["metric"].get("__name__", ""),
                                  item["metric"],
                                  values=item["values"])
                for item in items
            ]
        elif query_type == QueryType.instant:
            mr.data = [
                InstantVectorResult(item["metric"].get("__name__", ""),
                                    item["metric"],
                                    value=item["value"])
                for item in items
            ]
        return mr

    def get_metric_names(self, limit: int = -1) -> MetricResult:
        """List metric names known to the metadata endpoint.

        :param limit: forwarded as the ``limit`` query parameter when > 0;
            otherwise no limit is sent.
        :return: MetricResult whose data is the list of metric names, i.e.
            the keys of the endpoint's ``data`` object.
        """
        params = {"limit": limit} if limit > 0 else {}
        mr, payload = self._check_response(
            self._get(METRIC_METADATA, params),
            "Request failed, status_code != 200", "Prometheus API error: ")
        if payload is not None:
            mr.data = list(payload["data"])
        return mr

    def get_metric_labels(self, metric_name) -> MetricResult:
        """List the label names of the series matching *metric_name*.

        Queries the series endpoint with ``match[]=metric_name``. Returns the
        union of label keys across all returned series, in first-seen order,
        so labels present on only some series are not lost. The data is
        empty when nothing matches.
        """
        mr, payload = self._check_response(
            self._get(GET_LABEL_NAMES, {"match[]": metric_name}),
            f"Get metric labels for {metric_name} failed!",
            f"Get metric labels for {metric_name} failed => ")
        if payload is not None:
            # dict.fromkeys de-duplicates while preserving insertion order.
            mr.data = list(dict.fromkeys(
                label for series in (payload["data"] or []) for label in series
            ))
        return mr

    def get_label_values(self, label_name) -> MetricResult:
        """List the values of label *label_name*.

        Uses the ``/api/v1/label/<label_name>/values`` endpoint. A missing
        or null ``data`` field yields an empty list.
        """
        mr, payload = self._check_response(
            self._get(f"{GET_LABEL_VALUE}/{label_name}/values"),
            f"Get values for {label_name} failed!",
            f"Get label values for {label_name} failed => ")
        if payload is not None:
            mr.data = list(payload["data"] or [])
        return mr

    def _build_instant_query(self, task: InstantQueryTask) -> str:
        """Return the PromQL expression sent for the instant-query *task*.

        * Without ``task.aggregation``, this is the plain selector.
        * With an aggregation, the selector is first turned into a range
          vector with ``[task.interval]`` when an interval is set. It is
          then wrapped as ``aggregation(...)``, e.g.
          ``avg_over_time(sysom_cgroups[5m])``. Such range-vector functions
          also go through the instant query API.
        * When ``task.clause_label`` is set, ``task.clause`` followed by
          ``(label,...)`` is appended.
        """
        query = self._get_basic_promql_query(task)
        if task.aggregation is not None:
            if task.interval is not None:
                query = f"{query}[{task.interval}]"
            query = task.aggregation + "(" + query + ")"
            if task.clause_label is not None:
                query = (query + task.clause
                         + "(" + ",".join(task.clause_label) + ")")
        return query

    def _instant_query_one(self, task: InstantQueryTask) -> MetricResult:
        """Run one instant query evaluated at ``task.time`` and parse it."""
        res = self._get(QUERY_API, {
            "query": self._build_instant_query(task),
            "time": task.time,
        })
        return self._parse_res(QueryType.instant, res)

    def instant_query(self, queries: List[InstantQueryTask]) -> MetricResult:
        """Run *queries* one after another and merge their results.

        Stops at the first failing query; its error is reported, and data
        from the queries before it is kept.
        """
        return self._merge_results(
            self._instant_query_one(task) for task in queries)

    def _range_query_one(self, task: RangeQueryTask) -> MetricResult:
        """Run one range query over [start_time, end_time] and parse it.

        The step is sent as ``"<task.step>s"``, i.e. interpreted in seconds.
        """
        res = self._get(RANGE_QUERY_API, {
            "query": self._get_basic_promql_query(task),
            "start": task.start_time,
            "end": task.end_time,
            "step": f"{task.step}s",
        })
        return self._parse_res(QueryType.range, res)

    def range_query(self, queries: List[RangeQueryTask]) -> MetricResult:
        """Run *queries* one after another and merge their results.

        Stops at the first failing query; its error is reported, and data
        from the queries before it is kept.
        """
        return self._merge_results(
            self._range_query_one(task) for task in queries)
